反向传播如何实现?
在神经网络的训练过程中,我们使用反向传播(backpropagation)算法来计算权重矩阵 WWW 的梯度,并用这些梯度来更新 WWW。这个更新过程是为了使模型的预测误差最小化。反向传播的一个核心部分就是计算每一层的权重矩阵相对于损失函数的导数(梯度)。
假设我们有一个线性层,其计算公式是: y=xW+by = xW + by=xW+b 其中:
在反向传播中,我们需要计算损失函数 L\mathcal{L}L 对权重矩阵 WWW 的导数,即 ∂L∂W\frac{\partial \mathcal{L}}{\partial W}∂W∂L。
损失函数的梯度:
首先,我们有损失函数 L\mathcal{L}L 对输出 yyy 的梯度 ∂L∂y\frac{\partial \mathcal{L}}{\partial y}∂y∂L,这个梯度由上一层反向传播给我们,记为 grad
,其形状为 (\text{batch_size}, \text{num_out})。
链式法则:
根据链式法则,损失函数对权重矩阵 WWW 的梯度可以分解为:
∂L∂W=∂L∂y⋅∂y∂W\frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial y} \cdot \frac{\partial y}{\partial W}∂W∂L=∂y∂L⋅∂W∂y
计算 ∂y∂W\frac{\partial y}{\partial W}∂W∂y:
我们知道:
y=xWy = xWy=xW
所以,输出 yyy 对权重矩阵 WWW 的导数 ∂y∂W\frac{\partial y}{\partial W}∂W∂y 是输入 xxx:
具体地,假设 xxx 是一个单一的输入向量(形状为 (\text{num_in},)),则 yyy 对 WWW 的导数是:
∂yi∂Wjk=xj对于所有的 i,j,k\frac{\partial yi}{\partial W{jk}} = x_j \quad \text{对于所有的 } i, j, k∂Wjk∂yi=xj对于所有的 i,j,k
这意味着 yyy 对 WWW 的导数是输入 xxx 本身。
矩阵形式的梯度计算:
现在,我们将这些扩展到整个批次。对于每个样本 xxx 和对应的输出梯度 grad
,我们需要计算:
∂L∂W=xT⋅grad\frac{\partial \mathcal{L}}{\partial W} = x^T \cdot \text{grad}∂W∂L=xT⋅grad
其中:
grad
的形状为 (\text{batch_size}, \text{num_out})矩阵乘法 xT⋅gradx^T \cdot \text{grad}xT⋅grad 的结果是一个形状为 (\text{num_in}, \text{num_out}) 的矩阵,这正是我们需要的权重矩阵 WWW 的梯度。
假设我们有以下数据:
grad
:
grad=(0.10.20.30.40.50.6)\text{grad} = \begin{pmatrix} 0.1 & 0.2 \ 0.3 & 0.4 \ 0.5 & 0.6 \end{pmatrix}grad=⎝⎛0.10.30.50.20.40.6⎠⎞
形状为 (3,2)(3, 2)(3,2)
计算 xT⋅gradx^T \cdot \text{grad}xT⋅grad:
xT=(147258369)x^T = \begin{pmatrix} 1 & 4 & 7 \ 2 & 5 & 8 \ 3 & 6 & 9 \end{pmatrix}xT=⎝⎛123456789⎠⎞
xT⋅grad=(147258369)⋅(0.10.20.30.40.50.6)=(1⋅0.1+4⋅0.3+7⋅0.51⋅0.2+4⋅0.4+7⋅0.62⋅0.1+5⋅0.3+8⋅0.52⋅0.2+5⋅0.4+8⋅0.63⋅0.1+6⋅0.3+9⋅0.53⋅0.2+6⋅0.4+9⋅0.6)=(6.07.27.59.09.010.8)x^T \cdot \text{grad} = \begin{pmatrix} 1 & 4 & 7 \ 2 & 5 & 8 \ 3 & 6 & 9 \end{pmatrix} \cdot \begin{pmatrix} 0.1 & 0.2 \ 0.3 & 0.4 \ 0.5 & 0.6 \end{pmatrix} = \begin{pmatrix} 1 \cdot 0.1 + 4 \cdot 0.3 + 7 \cdot 0.5 & 1 \cdot 0.2 + 4 \cdot 0.4 + 7 \cdot 0.6 \ 2 \cdot 0.1 + 5 \cdot 0.3 + 8 \cdot 0.5 & 2 \cdot 0.2 + 5 \cdot 0.4 + 8 \cdot 0.6 \ 3 \cdot 0.1 + 6 \cdot 0.3 + 9 \cdot 0.5 & 3 \cdot 0.2 + 6 \cdot 0.4 + 9 \cdot 0.6 \end{pmatrix} = \begin{pmatrix} 6.0 & 7.2 \ 7.5 & 9.0 \ 9.0 & 10.8 \end{pmatrix}xT⋅grad=⎝⎛123456789⎠⎞⋅⎝⎛0.10.30.50.20.40.6⎠⎞=⎝⎛1⋅0.1+4⋅0.3+7⋅0.52⋅0.1+5⋅0.3+8⋅0.53⋅0.1+6⋅0.3+9⋅0.51⋅0.2+4⋅0.4+7⋅0.62⋅0.2+5⋅0.4+8⋅0.63⋅0.2+6⋅0.4+9⋅0.6⎠⎞=⎝⎛6.07.59.07.29.010.8⎠⎞
最后,我们对结果除以 batch_size
333 得到平均梯度:
13(6.07.27.59.09.010.8)=(2.02.42.53.03.03.6)\frac{1}{3} \begin{pmatrix} 6.0 & 7.2 \ 7.5 & 9.0 \ 9.0 & 10.8 \end{pmatrix} = \begin{pmatrix} 2.0 & 2.4 \ 2.5 & 3.0 \ 3.0 & 3.6 \end{pmatrix}31⎝⎛6.07.59.07.29.010.8⎠⎞=⎝⎛2.02.53.02.43.03.6⎠⎞
这就是最终的权重矩阵 WWW 的梯度。
通过反向传播,我们计算了输入 xxx 对权重矩阵 WWW 的梯度。这一步的目的是确定如何调整权重 WWW 以减少损失。这个过程通过矩阵乘法 xT⋅gradx^T \cdot \text{grad}xT⋅grad 实现,并且结果会对 batch_size
进行平均化。这个计算的背后是链式法则,它使得我们能够将梯度从输出层逐层传播到每个参数。